import pandas as pd
import numpy as np
from scipy import stats
import networkx as nx
import matplotlib.pyplot as plt
import copy
import torch
import torchvision
from nltk.corpus import wordnet as wn
import torchvision.transforms as transforms
import torchvision.models
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torchvision.transforms as tvtf
import tqdm
import csv
import pickle
import os
import timm
from sklearn.model_selection import train_test_split
from hierarchy.hierarchy import Hierarchy
from hierarchy.inference_rules import *
import clip
from utils.csv_utils import *
import argparse
from hierarchy.threshold_algorithms import GraphDistanceLoss

config_parser = parser = argparse.ArgumentParser(description='config', add_help=False)
parser.add_argument('-lli', '--low-limit', default=None, type=int,
                    metavar='N',
                    help='at which index of the timm models list to start')
parser.add_argument('-hli', '--high-limit', default=None, type=int,
                    metavar='N',
                    help='at which index of the timm models list to finish')
parser.add_argument('-n', '--cal-set-n', default=5000, type=int,
                    metavar='N',
                    help='Calibration set size (n). Recommended 5000 for Imagenet, 10000 for iNat21')
parser.add_argument('-re', '--repeats', default=1000, type=int,
                    metavar='N',
                    help='Number of repetitions for each model and alpha value')
parser.add_argument('-inat', '--inat', default=False, action='store_true',
                    help='Eval inat models. If false, eval imagenet models')
parser.add_argument('-rv', '--reverse', default=False, action='store_true',
                    help='Go over the models in reverse')
parser.add_argument('-rh', '--rebuild-hier', default=True, action='store_true',
                    help='Build hierarchy. For first run, set to true and later load it from file.')
parser.add_argument('-lh', '--load-hier', default=True, action='store_true',
                    help='Load hierarchy from file.')


metrics_names = ['accuracy', 'risk', 'coverage', 'avg_height']
threshold_algs = ['OTA', 'DARTS', 'CRC', 'CRC_01']

def get_hierarchy(rebuild_hier=False, load_hier=True, path='resources/imagenet1k_hier.pkl'):
    hierarchy = Hierarchy()
    if rebuild_hier:
        hierarchy.build_imagenet_tree()
        hierarchy.save_to_file(path)
    if load_hier:
        hierarchy = hierarchy.load_from_file(path)
    return hierarchy

def optimal_threshold_algorithm(hierarchy, y_scores_cal, y_true_cal, alpha=0.1):
    climb_inf_rule = get_inference_rule('Climbing', hierarchy)
    all_nodes_probs_cal = hierarchy.all_nodes_probs(y_scores_cal)
    preds_leaf_cal = y_scores_cal.max(dim=1)[1]
    correct_thetas = climb_inf_rule.get_tight_thresholds(all_nodes_probs_cal, preds_leaf_cal, y_true_cal)
    return climb_inf_rule.compute_quantile_threshold(correct_thetas, alpha=alpha)

def DARTS(hierarchy, y_scores_cal, y_true_cal, epsilon=0.1):
    # the reward for each node is: coverage * root entropy
    root_entropy = np.log2(hierarchy.num_leaves)
    rewards = hierarchy.coverage_vec * root_entropy
    # Step 1+2: get probabilities for all nodes and sum them upwards
    all_nodes_probs_cal = hierarchy.all_nodes_probs(y_scores_cal)
    # Step 3+4: init f_0, if its accuracy suffices then return it
    f_0_scores = rewards * all_nodes_probs_cal
    f_0_preds = f_0_scores.max(dim=0)[1]
    f_0_correctness = hierarchy.correctness(f_0_preds, y_true_cal).cpu()
    f_0_accuracy = f_0_correctness.sum().item() / len(f_0_correctness)
    if f_0_accuracy >= 1-epsilon:
        return 0
    # Step 5: calculate lambda bar
    r_max = rewards.max()
    r_root = rewards[hierarchy.root_index]
    lambda_bar = (r_max * (1-epsilon) - r_root) / epsilon
    # Step 6: binary search for optimal lambda
    min_lambda = 0
    max_lambda = lambda_bar.item()
    iteration_limit = 25
    confidence = 0.95
    desired_alpha = (1 - confidence) * 2
    num_examples = len(f_0_preds)
    for t in range(iteration_limit):
        lambda_t = (min_lambda + max_lambda) / 2
        f_t_scores = (rewards + lambda_t) * all_nodes_probs_cal
        f_t_preds = f_t_scores.max(dim=0)[1]
        f_t_correctness = hierarchy.correctness(f_t_preds, y_true_cal).cpu()
        f_t_accuracy = f_t_correctness.sum().item() / len(f_t_correctness)
        acc_bounds = stats.binom.interval(1-desired_alpha, num_examples, f_t_accuracy)
        acc_lower_bound = acc_bounds[0] / num_examples
        if acc_lower_bound > 1-epsilon:
            max_lambda = lambda_t
        else:
            min_lambda = lambda_t
    return max_lambda

# assuming the loss is monotonous
def conformal_risk_control(hierarchy, y_scores_cal, y_true_cal, alpha, B=1, loss='graph_distance'):
    lambdas = np.linspace(0, 1, 1001)
    all_nodes_probs_cal = hierarchy.all_nodes_probs(y_scores_cal)
    preds_leaf_cal = y_scores_cal.max(dim=1)[1]
    n = y_true_cal.shape[0]
    if loss == 'graph_distance':
        loss = GraphDistanceLoss(hierarchy)
        for lhat_idx, lam in enumerate(lambdas):
            _, hier_preds = climb_inf_rule.predict(all_nodes_probs_cal, preds_leaf_cal, lam)
            rhat = loss(hier_preds, y_true_cal)
            if (n/(n+1)) * rhat + B/(n+1) <= alpha:
                break
    elif loss == '01':
        for lhat_idx, lam in enumerate(lambdas):
            _, hier_preds = climb_inf_rule.predict(all_nodes_probs_cal, preds_leaf_cal, lam)
            rhat = 1 - hierarchy.correctness(hier_preds, y_true_cal).float().mean()
            if (n/(n+1)) * rhat + B/(n+1) <= alpha:
                break
    lhat_idx = max(lhat_idx - 1, 0) # Can't be -1.
    return lambdas[lhat_idx]
def validation(alg, hierarchy, y_scores_val, y_true_val, opt_result):
    preds_leaf_val = y_scores_val.max(dim=1)[1]
    all_nodes_probs_val = hierarchy.all_nodes_probs(y_scores_val)

    results = {}
    if alg == 'DARTS':
        opt_lambda = opt_result
        root_entropy = np.log2(hierarchy.num_leaves)
        rewards = hierarchy.coverage_vec * root_entropy
        probs = (rewards + opt_lambda) * all_nodes_probs_val
        preds = probs.max(dim=0)[1]
    else:
        opt_theta = opt_result
        _, preds = climb_inf_rule.predict(all_nodes_probs_val, preds_leaf_val, opt_theta)

    hier_correctness = hierarchy.correctness(preds, y_true_val).cpu()
    results['hier_accuracy'] = hier_correctness.sum().item() / len(hier_correctness)
    results['coverage'] = hierarchy.coverage(preds)
    return results


if __name__ == '__main__':
    args = parser.parse_args()
    if args.inat:
        path = 'resources/inat21.pkl'
        models_list_path = './models_lists/inat_models_list.txt'
    else:
        path = 'resources/imagenet1k_hier.pkl'
        models_list_path = './models_lists/imagenet_models_list.txt'
    torch.manual_seed(0)
    np.random.seed(0)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)
    hierarchy = get_hierarchy(rebuild_hier=args.rebuild_hier, load_hier=args.load_hier, path=path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    hierarchy.anc_matrix = hierarchy.anc_matrix.to(device).float()

    model_names = []
    with open(models_list_path, 'r') as f:
        for line in f:
            model_names.append(line.strip())
    
    if args.low_limit is not None and args.high_limit is not None:
        model_names = model_names[args.low_limit:args.high_limit]
    if args.reverse:
        model_names = model_names[::-1]

    alpha_vals = [0.005, 0.01, 0.05, 0.1, 0.15, 0.2, 0.3]
    n = args.cal_set_n
    n_repeats = args.repeats
    save_full_model_results = False
    mean_results = []
    if args.inat:
        mean_results_file_name = f'results/thresholds/n{n}_reps{n_repeats}_inat.csv'
    else:
        mean_results_file_name = f'results/thresholds/n{n}_reps{n_repeats}_0905.csv'

    num_models = len(model_names)
    for i, model_name in enumerate(model_names):
        print(f'model: {model_name}, {i+1}/{num_models}')
        model_results = []
        model_results_file_name = f'results/thresholds/{model_name.replace("/", "-")}_n{n}_reps{n_repeats}.csv'
        try:
            all_y_scores = torch.load(f'resources/models_y_scores/{model_name}.pt').cuda()
            all_y_true = torch.load(f'resources/models_ground_truth/{model_name}.pt').cuda()
            for alpha in alpha_vals:
                # if there's a line in mean_results_file_name skip it
                row_exists = False
                if os.path.exists(mean_results_file_name) and len(model_names) > 1:
                    with open(mean_results_file_name, 'r') as f:
                        reader = csv.reader(f)
                        for row in reader:
                            if row[1] == model_name and row[3] == str(100*(1-alpha)):
                                row_exists = True
                                break
                if row_exists:
                    continue
                print(f'alpha: {alpha}')
                results = []
                alpha_results = {a:{'Optimal Param': [], 'Accuracy':[], 'Coverage':[]} for a in threshold_algs}
                for rep in range(n_repeats):
                    # each rep produces a random calibration set (reprodicible across runs)
                    cal_indices, val_indices = train_test_split(np.arange(len(all_y_true)), train_size=n, stratify=all_y_true.cpu())
                    y_scores_cal = all_y_scores[cal_indices].cuda()
                    y_scores_val = all_y_scores[val_indices].cuda()
                    y_true_cal = all_y_true[cal_indices].long().cuda()
                    y_true_val = all_y_true[val_indices].long().cuda()
                    if args.inat:
                        # split the val set to 2 parts
                        y_scores_val_1, y_scores_val_2 = torch.split(y_scores_val, y_scores_val.shape[0] // 2)
                        y_true_val_1, y_true_val_2 = torch.split(y_true_val, y_true_val.shape[0] // 2)
                                            
                    # OTA
                    opt_theta = optimal_threshold_algorithm(hierarchy, y_scores_cal, y_true_cal, alpha=alpha)
                    alpha_results['OTA']['Optimal Param'].append(opt_theta)
                    if args.inat:
                        res_1 = validation('OTA', hierarchy, y_scores_val_1, y_true_val_1, opt_theta)
                        res_2 = validation('OTA', hierarchy, y_scores_val_2, y_true_val_2, opt_theta)
                        alpha_results['OTA']['Accuracy'].append((res_1['hier_accuracy'] + res_2['hier_accuracy'])/2)
                        alpha_results['OTA']['Coverage'].append((res_1['coverage'] + res_2['coverage'])/2)
                    else:
                        res = validation('OTA', hierarchy, y_scores_val, y_true_val, opt_theta)
                        alpha_results['OTA']['Accuracy'].append(res['hier_accuracy'])
                        alpha_results['OTA']['Coverage'].append(res['coverage'])

                    # DARTS
                    opt_lambda = DARTS(hierarchy, y_scores_cal, y_true_cal, epsilon=alpha)
                    alpha_results['DARTS']['Optimal Param'].append(opt_lambda)
                    if args.inat:
                        res_1 = validation('DARTS', hierarchy, y_scores_val_1, y_true_val_1, opt_lambda)
                        res_2 = validation('DARTS', hierarchy, y_scores_val_2, y_true_val_2, opt_lambda)
                        alpha_results['DARTS']['Accuracy'].append((res_1['hier_accuracy'] + res_2['hier_accuracy'])/2)
                        alpha_results['DARTS']['Coverage'].append((res_1['coverage'] + res_2['coverage'])/2)
                    else:
                        res = validation('DARTS', hierarchy, y_scores_val, y_true_val, opt_lambda)
                        alpha_results['DARTS']['Accuracy'].append(res['hier_accuracy'])
                        alpha_results['DARTS']['Coverage'].append(res['coverage'])
                
                    # CRC
                    opt_lambda = conformal_risk_control(hierarchy, y_scores_cal, y_true_cal, alpha=alpha)
                    alpha_results['CRC']['Optimal Param'].append(opt_lambda)
                    if args.inat:
                        res_1 = validation('CRC', hierarchy, y_scores_val_1, y_true_val_1, opt_theta)
                        res_2 = validation('CRC', hierarchy, y_scores_val_2, y_true_val_2, opt_theta)
                        alpha_results['CRC']['Accuracy'].append((res_1['hier_accuracy'] + res_2['hier_accuracy'])/2)
                        alpha_results['CRC']['Coverage'].append((res_1['coverage'] + res_2['coverage'])/2)
                    else:
                        res = validation('CRC', hierarchy, y_scores_val, y_true_val, opt_lambda)
                        alpha_results['CRC']['Accuracy'].append(res['hier_accuracy'])
                        alpha_results['CRC']['Coverage'].append(res['coverage'])
                    
                    # CRC_01
                    opt_lambda = conformal_risk_control(hierarchy, y_scores_cal, y_true_cal, alpha=alpha, loss='01')
                    alpha_results['CRC_01']['Optimal Param'].append(opt_lambda)
                    if args.inat:
                        res_1 = validation('CRC_01', hierarchy, y_scores_val_1, y_true_val_1, opt_theta)
                        res_2 = validation('CRC_01', hierarchy, y_scores_val_2, y_true_val_2, opt_theta)
                        alpha_results['CRC_01']['Accuracy'].append((res_1['hier_accuracy'] + res_2['hier_accuracy'])/2)
                        alpha_results['CRC_01']['Coverage'].append((res_1['coverage'] + res_2['coverage'])/2)
                    else:
                        res = validation('CRC_01', hierarchy, y_scores_val, y_true_val, opt_lambda)
                        alpha_results['CRC_01']['Accuracy'].append(res['hier_accuracy'])
                        alpha_results['CRC_01']['Coverage'].append(res['coverage'])
                    
                for alg in threshold_algs:
                    row = {}
                    row['Architecture'] = model_name
                    row['Algorithm'] = alg
                    row['Target Accuracy'] = 100*(1-alpha)
                    row['Optimal Param Result (mean)'] = np.mean([r for r in alpha_results[alg]['Optimal Param']])
                    row['Accuracy (mean)'] = np.mean([100*r for r in alpha_results[alg]['Accuracy']])
                    row['Accuracy (std)'] = np.std([100*r for r in alpha_results[alg]['Accuracy']])
                    row['Accuracy Error (mean)'] = np.mean([100*(r-(1-alpha)) for r in alpha_results[alg]['Accuracy']])
                    row['Accuracy Error (std)'] = np.std([100*(1-alpha)-100*r for r in alpha_results[alg]['Accuracy']])
                    row['Accuracy Error Abs (mean)'] = np.mean(np.abs([100*(1-alpha)-100*r for r in alpha_results[alg]['Accuracy']]))
                    row['Accuracy Error Abs (std)'] = np.std(np.abs([100*(1-alpha)-100*r for r in alpha_results[alg]['Accuracy']]))
                    row['Coverage (mean)'] = np.mean([100*r for r in alpha_results[alg]['Coverage']])
                    row['Coverage (std)'] = np.std([100*r for r in alpha_results[alg]['Coverage']])
                    model_results.append(row)
            
                # save model results to csv
                if save_full_model_results:
                    model_results_df = pd.DataFrame(model_results)
                    with open(model_results_file_name, 'a') as f:
                        model_results_df.to_csv(f, header=f.tell()==0)

            results_df = pd.DataFrame(model_results)
            # add to csv without overwriting existing file
            with open(mean_results_file_name, 'a') as f:
                results_df.to_csv(f, header=f.tell()==0)


        except Exception as e:
            print(f'Failed. model {model_name}. Error: {e}')
            with open('./resources/models_lists/failed_models.txt', 'a') as f:
                f.write(model_name + '\n')
    print('done')


